{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "1fgVWTMK9SNz",
      "metadata": {
        "id": "1fgVWTMK9SNz"
      },
      "source": [
        "~~~\n",
        "Copyright 2025 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "~~~\n",
        "\n",
        "# Quick start with Hugging Face\n",
        "\n",
        "<table><tbody><tr>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/TxGemma/[TxGemma]Quickstart_with_Hugging_Face.ipynb\">\n",
        "      <img alt=\"Google Colab logo\" src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" width=\"32px\"><br> Run in Google Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DQuickstart_with_Hugging_Face.ipynb\">\n",
        "      <img alt=\"GitHub logo\" src=\"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\" width=\"32px\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://huggingface.co/collections/google/txgemma-release-67dd92e931c857d15e4d1e87\">\n",
        "      <img alt=\"Hugging Face logo\" src=\"https://huggingface.co/front/assets/huggingface_logo-noborder.svg\" width=\"32px\"><br> View on Hugging Face\n",
        "    </a>\n",
        "  </td>\n",
        "</tr></tbody></table>\n",
        "\n",
        "This notebook provides a basic demo of TxGemma, a collection of large language models built on Gemma 2 that generate predictions, classifications, or text from therapeutics-related data. It contains standalone usage examples for both:\n",
        "\n",
        "- Predictive tasks, which expect a narrow form of prompting\n",
        "- Conversational use (for TxGemma-Chat model variants), which is more flexible and includes multi-turn interactions\n",
        "\n",
        "Learn more about the model at [this page](https://developers.google.com/health-ai-developer-foundations/txgemma)."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "t9xt2XZgaaH2",
      "metadata": {
        "id": "t9xt2XZgaaH2"
      },
      "source": [
        "## Setup\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the TxGemma model. Choose an appropriate runtime when starting your Colab session.\n",
        "\n",
        "You can try out TxGemma 2B or 9B* for free using a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "*To run the demo with TxGemma 9B on a T4 GPU, use 8-bit (int8) quantization to reduce memory usage. Note that the performance of quantized versions has not been evaluated."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "L9ITcQtdal7J",
      "metadata": {
        "id": "L9ITcQtdal7J"
      },
      "source": [
        "### Get access to TxGemma\n",
        "\n",
        "Before you get started, make sure that you have access to TxGemma models on Hugging Face:\n",
        "\n",
        "1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).\n",
        "2. Head over to the [TxGemma model page](https://huggingface.co/google/txgemma-2b-predict) and accept the usage conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "qRFQnPL2a9Dj",
      "metadata": {
        "id": "qRFQnPL2a9Dj"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Generate a Hugging Face `read` access token on the [access tokens page](https://huggingface.co/settings/tokens), then store it securely in the Colab Secrets manager:\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Paste your token into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SJRgvl-Wh_VM",
      "metadata": {
        "id": "SJRgvl-Wh_VM"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "qocWBSYmb0MA",
      "metadata": {
        "id": "qocWBSYmb0MA"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "JyPXnIjML6go",
      "metadata": {
        "id": "JyPXnIjML6go"
      },
      "outputs": [],
      "source": [
        "! pip install --upgrade --quiet accelerate bitsandbytes huggingface_hub transformers"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4KUYpBH1cZA1",
      "metadata": {
        "id": "4KUYpBH1cZA1"
      },
      "source": [
        "## Load model from Hugging Face Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "kTNyftcqF3YO",
      "metadata": {
        "id": "kTNyftcqF3YO"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "\n",
        "MODEL_VARIANT = \"9b-chat\"  # @param [\"2b-predict\", \"9b-chat\", \"9b-predict\", \"27b-chat\", \"27b-predict\"]\n",
        "\n",
        "model_id = f\"google/txgemma-{MODEL_VARIANT}\"\n",
        "\n",
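        "# Load the 2B variant without quantization; load larger variants in 8-bit to reduce memory usage\n",
        "if MODEL_VARIANT == \"2b-predict\":\n",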
        "    additional_args = {}\n",
        "else:\n",
        "    additional_args = {\n",
        "        \"quantization_config\": BitsAndBytesConfig(load_in_8bit=True)\n",
        "    }\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    device_map=\"auto\",\n",
        "    **additional_args,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "KLpNKp8RFwoe",
      "metadata": {
        "id": "KLpNKp8RFwoe"
      },
      "source": [
        "After loading, you can directly use the model and tokenizer, which gives you complete control over the inference process, including tokenization and postprocessing of outputs.\n",
        "\n",
        "Alternatively, you can use the [`pipeline`](https://huggingface.co/docs/transformers/en/main_classes/pipelines) API, which provides a simple way to use the model for inference while abstracting away complex details. Here, instantiate a text generation pipeline using the loaded model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "UHzWTsym7QIp",
      "metadata": {
        "id": "UHzWTsym7QIp"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Device set to use cuda:0\n"
          ]
        }
      ],
      "source": [
        "from transformers import pipeline\n",
        "\n",
        "pipe = pipeline(\n",
        "    \"text-generation\",\n",
        "    model=model,\n",
        "    tokenizer=tokenizer,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "VzkBsjIiMIll",
      "metadata": {
        "id": "VzkBsjIiMIll"
      },
      "source": [
        "The following sections include standalone examples demonstrating how to use the model both directly and with the `pipeline` API. In practice, select the method that best suits your use case."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dqX-SPryB97S",
      "metadata": {
        "id": "dqX-SPryB97S"
      },
      "source": [
        "## Format prompts for therapeutic tasks\n",
        "\n",
        "The following sections in this notebook demonstrate prompting TxGemma for therapeutic development tasks from [Therapeutics Data Commons](https://tdcommons.ai/) (TDC).\n",
        "\n",
        "For these predictive tasks, prompts should be formatted according to the TDC structure, including:\n",
        "\n",
        "- **Instruction:** Briefly describes the task.\n",
        "\n",
        "- **Context:** Provides 2-3 sentences of relevant biochemical background, derived from TDC descriptions and literature.\n",
        "\n",
        "- **Question:** Queries a specific therapeutic property, incorporating textual representations of therapeutics and/or targets as inputs.\n",
        "\n",
        "  - Inputs can include SMILES strings, amino acid sequences, nucleotide sequences, and natural language text.\n",
        "\n",
        "  - Optional few-shot examples can be provided."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "_z8YZ8pT1QD5",
      "metadata": {
        "id": "_z8YZ8pT1QD5"
      },
      "source": [
        "### Load prompt template\n",
        "\n",
        "First, load a JSON file that contains the prompt format for various TDC tasks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tUhpMxJq1yNi",
      "metadata": {
        "id": "tUhpMxJq1yNi"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from huggingface_hub import hf_hub_download\n",
        "\n",
        "tdc_prompts_filepath = hf_hub_download(\n",
        "    repo_id=model_id,\n",
        "    filename=\"tdc_prompts.json\",\n",
        ")\n",
        "\n",
        "with open(tdc_prompts_filepath, \"r\") as f:\n",
        "    tdc_prompts_json = json.load(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Ih7ZOg10CNrt",
      "metadata": {
        "id": "Ih7ZOg10CNrt"
      },
      "source": [
        "### Prepare sample prompt\n",
        "\n",
        "Construct a prompt using the template and an input drug SMILES string from the [BBB (Blood-Brain Barrier), Martins et al.](https://tdcommons.ai/single_pred_tasks/adme#bbb-blood-brain-barrier-martins-et-al) dataset. This prompt will be used for generating predictions in the next sections.\n",
        "\n",
        "**Note:** The prompt should not be modified from the template except for replacing the inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eWBtd87BCQws",
      "metadata": {
        "id": "eWBtd87BCQws"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Formatted prompt:\n",
            "\n",
            "Instructions: Answer the following question about drug properties.\n",
            "Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.\n",
            "Question: Given a drug SMILES string, predict whether it\n",
            "(A) does not cross the BBB (B) crosses the BBB\n",
            "Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\n",
            "Answer:\n"
          ]
        }
      ],
      "source": [
        "# Set example task and input\n",
        "task_name = \"BBB_Martins\"\n",
        "input_type = \"{Drug SMILES}\"\n",
        "drug_smiles = \"CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\"\n",
        "\n",
        "TDC_PROMPT = tdc_prompts_json[task_name].replace(input_type, drug_smiles)\n",
        "print(\"Formatted prompt:\\n\")\n",
        "print(TDC_PROMPT)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "QmkhDkBicAZV",
      "metadata": {
        "id": "QmkhDkBicAZV"
      },
      "source": [
        "## Explore predictive capabilities\n",
        "\n",
        "TxGemma models are designed to process and understand information related to various therapeutic modalities and targets, including small molecules, proteins, nucleic acids, diseases, and cell lines, and can generate predictions on a broad set of therapeutic development tasks."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "oBmbHe2Hg4Y_",
      "metadata": {
        "id": "oBmbHe2Hg4Y_"
      },
      "source": [
        "### Run inference on a therapeutic task\n",
        "\n",
        "This section demonstrates prompting TxGemma for a predictive task from TDC.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ErVQz-9bNQKC",
      "metadata": {
        "id": "ErVQz-9bNQKC"
      },
      "source": [
        "**Run the model directly**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "UgEN-mgbO6Gm",
      "metadata": {
        "id": "UgEN-mgbO6Gm"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instructions: Answer the following question about drug properties.\n",
            "Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.\n",
            "Question: Given a drug SMILES string, predict whether it\n",
            "(A) does not cross the BBB (B) crosses the BBB\n",
            "Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\n",
            "Answer: (B)\n"
          ]
        }
      ],
      "source": [
        "# Use sample prompt for a predictive task from TDC\n",
        "prompt = TDC_PROMPT\n",
        "\n",
        "# Prepare tokenized inputs and move them to the model's device\n",
        "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "# Generate response\n",
        "outputs = model.generate(**inputs, max_new_tokens=8)\n",
        "response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Vzlmxr1kNXGU",
      "metadata": {
        "id": "Vzlmxr1kNXGU"
      },
      "source": [
        "**Run with the `pipeline` API**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VTjH98DyNfbN",
      "metadata": {
        "id": "VTjH98DyNfbN"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instructions: Answer the following question about drug properties.\n",
            "Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.\n",
            "Question: Given a drug SMILES string, predict whether it\n",
            "(A) does not cross the BBB (B) crosses the BBB\n",
            "Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\n",
            "Answer: (B)\n"
          ]
        }
      ],
      "source": [
        "# Use sample prompt for a predictive task from TDC\n",
        "prompt = TDC_PROMPT\n",
        "\n",
        "# Generate response\n",
        "outputs = pipe(prompt, max_new_tokens=8)\n",
        "response = outputs[0][\"generated_text\"]\n",
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "V4WlXpx2fGN4",
      "metadata": {
        "id": "V4WlXpx2fGN4"
      },
      "source": [
        "## Explore conversational capabilities with TxGemma-Chat\n",
        "\n",
        "TxGemma features conversational models that add reasoning and explainability to predictions and can be used in multi-turn interactions. Their conversational ability comes at the expense of some predictive power.\n",
        "\n",
        "**For this section, make sure that you have selected a TxGemma-Chat model variant.**"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7lh_qkyjjCkC",
      "metadata": {
        "id": "7lh_qkyjjCkC"
      },
      "source": [
        "### Ask questions in a multi-turn conversation\n",
        "\n",
        "This section demonstrates prompting TxGemma for conversational use, adhering to the [Gemma chat template](https://ai.google.dev/gemma/docs/core/prompt-structure).\n",
        "\n",
        "In this example, first prompt the model to answer a question regarding a predictive task using the TDC format. Then, ask a follow-up question requesting the model to provide its reasoning for the predicted answer."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "B59raCIANqb3",
      "metadata": {
        "id": "B59raCIANqb3"
      },
      "source": [
        "**Run the model directly**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3ZmG3JPbHd3Q",
      "metadata": {
        "id": "3ZmG3JPbHd3Q"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": "\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**User:**\n\nInstructions: Answer the following question about drug properties.\nContext: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.\nQuestion: Given a drug SMILES string, predict whether it\n(A) does not cross the BBB (B) crosses the BBB\nDrug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\nAnswer:\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**TxGemma:**\n\n(B)\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**User:**\n\nExplain your reasoning based on the molecule structure.\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**TxGemma:**\n\nHere's the breakdown of why the drug with SMILES CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21 likely crosses the BBB:\n\n* **Small Size:** The molecule is relatively small, which is a general favorable characteristic for BBB penetration. Larger molecules have a harder time squeezing through the tight junctions between brain endothelial cells.\n* **Lipophilicity (Hydrophobicity):** The presence of multiple carbon and hydrogen atoms in the benzene rings and alkyl chains makes the molecule predominantly lipophilic (fat-loving).  Lipophilicity is a key determinant of BBB permeability. The more lipophilic a drug, the easier it crosses the lipid-rich cell membranes of the BBB.\n* **Lack of Charged Groups:**  The molecule lacks large, charged groups (like carboxyl or amine groups).  Charged groups tend to be repelled by the lipid bilayer of the BBB, making it harder to cross. \n* **Absence of Specific BBB Targets:** While some drugs have specific transporters that help them across the BBB, this molecule doesn't appear to have any obvious targeting groups.\n\n**In summary, the drug's small size, lipophilic nature, lack of significant charge, and absence of specific BBB-targeting groups suggest that it likely possesses a good ability to cross the blood-brain barrier.**\n\n**Important Note:** This is a general prediction based on structural features. Actual BBB permeability is complex and can be influenced by many factors. Experimental validation is always necessary to confirm drug penetration ability. \n\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from IPython.display import display, Markdown\n",
        "\n",
        "questions = [\n",
        "    TDC_PROMPT,  # Initial question is a predictive task from TDC\n",
        "    \"Explain your reasoning based on the molecule structure.\"\n",
        "]\n",
        "\n",
        "messages = []\n",
        "\n",
        "display(Markdown(\"\\n\\n---\\n\\n\"))\n",
        "for question in questions:\n",
        "    display(Markdown(f\"**User:**\\n\\n{question}\\n\\n---\\n\\n\"))\n",
        "    messages.append(\n",
        "        { \"role\": \"user\", \"content\": question },\n",
        "    )\n",
        "    # Apply the tokenizer's built-in chat template\n",
        "    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\")\n",
        "    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=512)\n",
        "    response = tokenizer.decode(outputs[0, len(inputs[0]):], skip_special_tokens=True)\n",
        "    display(Markdown(f\"**TxGemma:**\\n\\n{response}\\n\\n---\\n\\n\"))\n",
        "    messages.append(\n",
        "        { \"role\": \"assistant\", \"content\": response},\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5B9p8mRaNssg",
      "metadata": {
        "id": "5B9p8mRaNssg"
      },
      "source": [
        "**Run with the `pipeline` API**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Y-N8-xHgNwGC",
      "metadata": {
        "id": "Y-N8-xHgNwGC"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": "\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**User:**\n\nInstructions: Answer the following question about drug properties.\nContext: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.\nQuestion: Given a drug SMILES string, predict whether it\n(A) does not cross the BBB (B) crosses the BBB\nDrug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21\nAnswer:\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**TxGemma:**\n\n(B)\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**User:**\n\nExplain your reasoning based on the molecule structure.\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/markdown": "**TxGemma:**\n\nHere's the breakdown of why the drug with SMILES CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21 likely crosses the BBB:\n\n* **Small Size:** The molecule is relatively small, which is a general favorable characteristic for BBB penetration. Larger molecules have a harder time squeezing through the tight junctions between brain endothelial cells.\n* **Lipophilicity (Hydrophobicity):** The presence of multiple carbon and hydrogen atoms in the benzene rings and alkyl chains makes the molecule predominantly lipophilic (fat-loving).  Lipophilicity is a key determinant of BBB permeability. The more lipophilic a drug, the easier it crosses the lipid-rich cell membranes of the BBB.\n* **Lack of Charged Groups:**  The molecule lacks large, charged groups (like carboxyl or amine groups).  Charged groups tend to be repelled by the lipid bilayer of the BBB, making it harder to cross. \n* **Absence of Specific BBB Targets:** While some drugs have specific transporters that help them across the BBB, this molecule doesn't appear to have any obvious targeting groups.\n\n**In summary, the drug's small size, lipophilic nature, lack of significant charge, and absence of specific BBB-targeting groups suggest that it likely possesses a good ability to cross the blood-brain barrier.**\n\n**Important Note:** This is a general prediction based on structural features. Actual BBB permeability is complex and can be influenced by many factors. Experimental validation is always necessary to confirm drug penetration ability.\n\n---\n\n",
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from IPython.display import display, Markdown\n",
        "\n",
        "questions = [\n",
        "    TDC_PROMPT,  # Initial question is a predictive task from TDC\n",
        "    \"Explain your reasoning based on the molecule structure.\"\n",
        "]\n",
        "\n",
        "messages = []\n",
        "\n",
        "display(Markdown(\"\\n\\n---\\n\\n\"))\n",
        "for question in questions:\n",
        "    display(Markdown(f\"**User:**\\n\\n{question}\\n\\n---\\n\\n\"))\n",
        "    messages.append(\n",
        "        { \"role\": \"user\", \"content\": question },\n",
        "    )\n",
        "    # The pipeline applies the chat template and returns the full conversation,\n",
        "    # with the new assistant turn appended as the last message\n",
        "    outputs = pipe(messages, max_new_tokens=512)\n",
        "    messages = outputs[0][\"generated_text\"]\n",
        "    response = messages[-1][\"content\"].strip()\n",
        "    display(Markdown(f\"**TxGemma:**\\n\\n{response}\\n\\n---\\n\\n\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "PHTxQttKYNpa",
      "metadata": {
        "id": "PHTxQttKYNpa"
      },
      "source": [
        "# Next steps\n",
        "\n",
        "Explore the other [TxGemma notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[TxGemma]Quickstart_with_Hugging_Face.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
